import torch

from tools import GetConfusionMatrix


def test(model, ssq_loader):
    num1 = 0
    num2 = 0
    for data_lists, target_lists in ssq_loader:
        model.eval()
        pred_lists = model(data_lists)
        pred_target = torch.max(pred_lists, 1)[1]
        true_target = torch.max(target_lists, 1)[1]
        tmp_num1 = GetConfusionMatrix(pred_target, true_target)
        tmp_num2 = pred_target.numel()
        num1 += tmp_num1
        num2 += tmp_num2

    print(num1 / num2)
